import torch 
from transformer_model import BertConfig, FixOut


if __name__ == "__main__":
    # Smoke test: feed a random batch through FixOut and print the result's shape.
    # The input is drawn before the model is built so RNG consumption order is
    # unchanged (model construction may itself draw random numbers).
    # NOTE(review): input is (5, 10, 64) with img_size=(8, 8); presumably the
    # last dim must equal 8 * 8 = 64 — confirm against FixOut.
    sample_batch = torch.rand(5, 10, 64)

    model_cfg = BertConfig(img_size=(8, 8))
    model = FixOut(model_cfg)

    prediction = model(sample_batch)
    print(prediction.shape)

